# http://proceedings.mlr.press/v101/huang19a/huang19a.pdf
# https://www.researchgate.net/publication/220875351_Generative_Models_for_Labeling_Multi-object_Configurations_in_Images
# https://www.tensorflow.org/datasets/catalog/open_images_v4
# Auto-Encoding Progressive Generative Adversarial Networks For 3D Multi Object Scenes
# TODO
# for dataset KITTI (as AD case study) - for the built model
# 1. report model loss for validation dataset - Done
# 2. visualize reconstructed images - Done
# 3. Grid search (K, cov type) for Gaussian mixture log p comparison (or Bayesian parameter optimization) - SKIP
# reason: need to focus on the core idea - GM (Gaussian mixture) is better than G (single Gaussian) in autonomous driving on a simplified case
# 4. read about the infinite Gaussian mixture model https://www.seas.harvard.edu/courses/cs281/papers/rasmussen-1999a.pdf
%config Completer.use_jedi = False
from ipywidgets import IntProgress
import matplotlib.pyplot as plt
from tensorflow.keras import layers, losses
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import logging
import tensorflow_datasets as tfds
import pandas as pd
from tqdm import tqdm_notebook as tqdm
from sklearn.mixture import GaussianMixture
import os
# Reproducibility: a single seed value drives both the NumPy and the
# TensorFlow random generators, so changing it here changes both.
seed = 1
np.random.seed(seed)
tf.random.set_seed(seed)

# Training hyperparameters.
batch_size = 32  # images per batch
epochs = 10  # number of passes over the training data
# Dataset selector: 'bdd100k' is read from local image folders; any other
# name is resolved through TensorFlow Datasets.
dataset_name = 'wider_face'

if dataset_name != 'bdd100k':
    # as_supervised=False keeps each element as a feature dict (the image is
    # read through the 'image' key further down); download=False means the
    # dataset is not downloaded by this call.
    train_ds, test_ds, validation_ds = tfds.load(
        name=dataset_name,
        split=['train', 'test', 'validation'],
        as_supervised=False,
        download=False,
    )
else:
    # The three folders are loaded in train / test / validation order and are
    # already batched by the loader.
    load_image_dir = tf.keras.preprocessing.image_dataset_from_directory
    train_ds = load_image_dir(
        directory='../data/bdd100k/images/10k/train1/', batch_size=batch_size)
    test_ds = load_image_dir(
        directory='../data/bdd100k/images/10k/test1/', batch_size=batch_size)
    validation_ds = load_image_dir(
        directory='../data/bdd100k/images/10k/val1/', batch_size=batch_size)
# Gather the static shape of every training element into a DataFrame so the
# image-size distribution can be inspected.
if dataset_name == 'bdd100k':
    # Elements are (images, labels) batches, hence the leading batch axis.
    shape_columns = ['batch', 'height', 'width', 'depth']
    dims = [element[0].shape.as_list() for element in train_ds]
else:
    # Elements are feature dicts holding one un-batched image each.
    shape_columns = ['height', 'width', 'depth']
    dims = [element['image'].shape.as_list() for element in train_ds]

dims_df = pd.DataFrame.from_records(data=dims, columns=shape_columns)
dims_df.describe()
| | height | width | depth |
|---|---|---|---|
| count | 12880.000000 | 12880.0 | 12880.0 |
| mean | 888.309627 | 1024.0 | 3.0 |
| std | 350.513446 | 0.0 | 0.0 |
| min | 171.000000 | 1024.0 | 3.0 |
| 25% | 682.000000 | 1024.0 | 3.0 |
| 50% | 760.000000 | 1024.0 | 3.0 |
| 75% | 1024.000000 | 1024.0 | 3.0 |
| max | 9108.000000 | 1024.0 | 3.0 |
# Target image side: for each axis take the largest power of two that does
# not exceed the smallest value observed on that axis, then use the smaller
# of the two for both height and width so the network input is square.
pow2_sides = [
    2 ** int(np.log2(dims_df[axis].min())) for axis in ('height', 'width')
]
height = width = min(pow2_sides)
height, width
(128, 128)
def _scale_and_resize(example):
    """Cast a feature dict's 'image' to float32, divide by 255 and resize it.

    The target size is read from the module-level `height` and `width`.
    """
    pixels = tf.cast(example['image'], dtype=tf.float32) / 255.
    return tf.image.resize(images=pixels, size=[height, width])


def _as_autoencoder_pairs(dataset):
    """Turn a dataset of image batches into (input, target) pairs with target == input.

    Mapping each batch to `(x, x)` guarantees that input and target are the
    very same tensor. Zipping the dataset with itself would instead run the
    input pipeline twice and could pair different images if the underlying
    dataset shuffles.
    """
    return dataset.map(lambda images: (images, images))


if dataset_name == 'bdd100k':
    # Directory datasets yield (images, labels) batches: drop the labels and
    # divide the pixel values by 255.
    train_ds = train_ds.map(lambda images, labels: images / 255.)
    test_ds = test_ds.map(lambda images, labels: images / 255.)
    validation_ds = validation_ds.map(lambda images, labels: images / 255.)
else:
    # TFDS elements are single, variable-sized images: rescale, resize to a
    # common square shape, then batch (incomplete final batches are dropped).
    train_ds = train_ds.map(_scale_and_resize).batch(batch_size, drop_remainder=True)
    test_ds = test_ds.map(_scale_and_resize).batch(batch_size, drop_remainder=True)
    validation_ds = validation_ds.map(_scale_and_resize).batch(batch_size, drop_remainder=True)

# (input, target) datasets for autoencoder training and validation.
train_ds_double_zipped = _as_autoencoder_pairs(train_ds)
test_ds_double_zipped = _as_autoencoder_pairs(test_ds)
validation_ds_double_zipped = _as_autoencoder_pairs(validation_ds)
# Number of units in the autoencoder bottleneck; passed to the CAE constructor
# and also embedded in the saved-model path.
latent_dim = 4096
class CAE(tf.keras.Model):
    """Convolutional autoencoder made of an `encoder` and a `decoder` sub-model.

    The model is deterministic: `call` simply decodes the encoder output
    (there is no sampling step and no KL term).

    Encoder: two stride-2 3x3 convolutions (32 then 64 filters, ReLU),
    flatten, and a linear Dense layer of `latent_dim` units.
    Decoder: a ReLU Dense layer reshaped to (height/4, width/4, 32), two
    stride-2 transposed convolutions (64 then 32 filters, ReLU) and a final
    stride-1 transposed convolution to 3 channels without activation.

    The image size is taken from the module-level `height` and `width` at
    construction time. Constructing the model prints both summaries.
    """

    def __init__(self, latent_dim):
        """Build the encoder and decoder for a bottleneck of `latent_dim` units."""
        super().__init__()
        self.latent_dim = latent_dim
        self.logger = logging.getLogger('CAE')

        # Layers are created in the same order as before so that Keras'
        # auto-generated layer names stay stable.
        encoder_stack = [
            layers.InputLayer(input_shape=(height, width, 3)),
            layers.Conv2D(filters=32, kernel_size=3, strides=(2, 2), activation='relu'),
            layers.Conv2D(filters=64, kernel_size=3, strides=(2, 2), activation='relu'),
            layers.Flatten(),
            # Linear projection onto the latent space (no activation).
            layers.Dense(latent_dim),
        ]
        self.encoder = tf.keras.Sequential(layers=encoder_stack, name='encoder')

        # Spatial size of the feature map the decoder starts from; the two
        # stride-2 transposed convolutions bring it back to the full size.
        base_h = int(height / 4)
        base_w = int(width / 4)
        decoder_stack = [
            layers.InputLayer(input_shape=(latent_dim,)),
            layers.Dense(units=base_h * base_w * 32, activation=tf.nn.relu),
            layers.Reshape(target_shape=(base_h, base_w, 32)),
            layers.Conv2DTranspose(
                filters=64, kernel_size=3, strides=2, padding='same', activation='relu'),
            layers.Conv2DTranspose(
                filters=32, kernel_size=3, strides=2, padding='same', activation='relu'),
            # Output layer: 3 channels, no activation.
            layers.Conv2DTranspose(filters=3, kernel_size=3, strides=1, padding='same'),
        ]
        self.decoder = tf.keras.Sequential(layers=decoder_stack, name='decoder')

        for sub_model in (self.encoder, self.decoder):
            sub_model.summary()

    def call(self, x):
        """Return the reconstruction of `x`: decoder(encoder(x))."""
        return self.decoder(self.encoder(x))
# Instantiate the autoencoder (the constructor also prints both sub-model summaries).
cae = CAE(latent_dim)
# Optimise with Adam on the mean squared error between target and prediction.
cae.compile(optimizer='adam', loss=losses.MeanSquaredError())
Model: "encoder" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d (Conv2D) (None, 63, 63, 32) 896 _________________________________________________________________ conv2d_1 (Conv2D) (None, 31, 31, 64) 18496 _________________________________________________________________ flatten (Flatten) (None, 61504) 0 _________________________________________________________________ dense (Dense) (None, 4096) 251924480 ================================================================= Total params: 251,943,872 Trainable params: 251,943,872 Non-trainable params: 0 _________________________________________________________________ Model: "decoder" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_1 (Dense) (None, 32768) 134250496 _________________________________________________________________ reshape (Reshape) (None, 32, 32, 32) 0 _________________________________________________________________ conv2d_transpose (Conv2DTran (None, 64, 64, 64) 18496 _________________________________________________________________ conv2d_transpose_1 (Conv2DTr (None, 128, 128, 32) 18464 _________________________________________________________________ conv2d_transpose_2 (Conv2DTr (None, 128, 128, 3) 867 ================================================================= Total params: 134,288,323 Trainable params: 134,288,323 Non-trainable params: 0 _________________________________________________________________
# Location of the saved model; the dataset name and latent size are part of
# the path, so each configuration gets its own saved model.
model_file_path = f'./models/cae_dataset_{dataset_name}_z_dim_{latent_dim}'
print(f'model path = {model_file_path}')
model path = ./models/cae_dataset_wider_face_z_dim_4096
# Reuse a previously saved model when one exists for this configuration;
# otherwise train from scratch and save the result.
if os.path.exists(model_file_path):
    print('loading saved model')
    cae = tf.keras.models.load_model(filepath=model_file_path)
else:
    print('building model')
    # use checkpoints to save model fitting progress
    # https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/ModelCheckpoint
    checkpoint_filepath = './checkpoint'
    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_filepath,
        save_weights_only=True,
        monitor='val_loss',
        # val_loss is an error, so the best epoch is the one with the LOWEST
        # value; 'max' would keep the worst weights seen so far.
        mode='min',
        save_best_only=True)
    # Model weights are saved at the end of every epoch, if it's the best seen
    # so far. Note that the test split is what is passed as validation data here.
    cae.fit(x=train_ds_double_zipped,
            validation_data=test_ds_double_zipped,
            epochs=epochs,
            callbacks=[model_checkpoint_callback])
    # The model weights (that are considered the best) are loaded into the model.
    cae.load_weights(checkpoint_filepath)
    print('saving model')
    cae.save(filepath=model_file_path)
loading saved model
# Create the validation dataset tensor: stack all validation batches along
# the batch axis with a single concat (no dummy initial state to strip, and
# no repeated re-copying of the growing tensor).
validation_ds_tensor = tf.concat(values=list(validation_ds), axis=0)

# Calculate the loss; it can be compared across datasets because the data is
# scaled to the range 0 to 1.
y_predicted = cae.predict(validation_ds)
cae_loss = cae.loss(y_pred=y_predicted, y_true=validation_ds_tensor).numpy()
# Fixed-point formatting actually prints four decimals (rounding a float32
# and formatting it could still print the full float64 expansion).
print(f'CAE loss for dataset {dataset_name} = {cae_loss:.4f}')
CAE loss for dataset wider_face = 0.017400000244379044
# Plot decoded images: for the first validation batch, show each original
# image (left) next to its reconstruction (right).
for batch in validation_ds.take(1):
    z = cae.encoder(batch).numpy()
    decoded_imgs = cae.decoder(z).numpy()
    for i in range(batch.shape[0]):
        fig, (ax1, ax2) = plt.subplots(1, 2)
        ax1.imshow(batch[i])
        # The decoder's output layer has no activation, so reconstructions can
        # leave [0, 1]; clip explicitly instead of letting imshow warn and clip.
        ax2.imshow(np.clip(decoded_imgs[i], 0.0, 1.0))
        # Render each figure as soon as it is complete so that figures do not
        # pile up (a batch produces more than 20 of them).
        plt.show()
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). 
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/ipykernel_launcher.py:8: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). 
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
# Getting the z tensor: encode the test split batch by batch and stack the
# latent codes into one (num_images, latent_dim) tensor.
# `tqdm.notebook.tqdm` is the supported replacement for the deprecated
# `tqdm.tqdm_notebook`.
from tqdm.notebook import tqdm as notebook_tqdm

test_cardinality = test_ds.cardinality().numpy()
inf_or_unknown_cardinality = test_cardinality in (
    tf.data.INFINITE_CARDINALITY, tf.data.UNKNOWN_CARDINALITY)
# Fall back to a fixed number of batches when the dataset size is not known.
batches = 500 if inf_or_unknown_cardinality else test_cardinality

z_batches = []
with notebook_tqdm(total=batches) as pbar:
    for batch in test_ds.take(batches):
        z = cae.encoder(batch).numpy()
        z_batches.append(z)
        pbar.update(1)

if not z_batches:
    raise ValueError('test_ds yielded no batches; cannot build the latent tensor')

# One concat at the end instead of re-copying the growing tensor per batch.
z_tensor = tf.concat(z_batches, axis=0)
z_tensor.shape
/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/ipykernel_launcher.py:8: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0 Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
TensorShape([16096, 4096])
# Compare a single Gaussian against Gaussian mixtures on the latent codes.
# The codes are split 80/20 in their existing order (no shuffling): models
# are fitted on the first part and scored on the held-out remainder.
z_np = z_tensor.numpy()
n_z = z_np.shape[0]
n_z_train = int(0.8 * n_z)
z_train = z_np[:n_z_train]
z_test = z_np[n_z_train:]

random_state = 1
reg_covar = 0.1
cov_type = 'diag'

print(f"""For Dateset "{dataset_name}" Calculating relative difference of log likelihood """)
print(f'Latent_dim = {latent_dim}, Gaussiam Mixture covariance type = {cov_type} and reg_covar = {reg_covar} ')
print('############################ ')

# Baseline: a one-component "mixture", i.e. a single Gaussian, built with
# the same settings (including random_state) as the mixtures below.
g_fit = GaussianMixture(n_components=1, covariance_type=cov_type,
                        random_state=random_state, reg_covar=reg_covar).fit(z_train)
# score() is the per-sample average log-likelihood on the held-out codes.
logp_g = g_fit.score(X=z_test)

for k in [10, 20, 50, 70, 80, 100, 200]:
    try:
        gm_fit = GaussianMixture(n_components=k, covariance_type=cov_type,
                                 random_state=random_state, reg_covar=reg_covar).fit(z_train)
        logp_gm = gm_fit.score(X=z_test)
    except ValueError as e:
        # scikit-learn raises ValueError when a component's covariance is
        # ill-defined (e.g. too many components); report it and try the next k.
        print(f'Catched expection {e} ')
        continue
    # Improvement of the mixture over the single Gaussian, relative to the
    # magnitude of the baseline log-likelihood.
    rel_diff_logps = (logp_gm - logp_g) / np.abs(logp_g)
    print(f'logp Gaussin Mixture with k = {k} = {logp_gm} ')
    print(f'logp Gaussian Diagonal = {logp_g} ')
    print(f'At k = {k} , rel_diff for logps = {rel_diff_logps} ')
    print('############## ')
For Dateset "wider_face" Calculating relative difference of log likelihood Latent_dim = 4096, Gaussiam Mixture covariance type = diag and reg_covar = 0.1 ############################
/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/sklearn/mixture/_base.py:269: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data. % (init + 1), ConvergenceWarning)
logp Gaussin Mixture with k = 10 = -155.9958668831587 logp Gaussian Diagonal = -176.16444328638678 At k = 10 , rel_diff for logps = 0.11448721448539113 ##############
/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/sklearn/mixture/_base.py:269: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data. % (init + 1), ConvergenceWarning)
logp Gaussin Mixture with k = 20 = -151.23816837550027 logp Gaussian Diagonal = -176.16444328638678 At k = 20 , rel_diff for logps = 0.14149435859973403 ############## logp Gaussin Mixture with k = 50 = -137.3661035053167 logp Gaussian Diagonal = -176.16444328638678 At k = 50 , rel_diff for logps = 0.22023933466526194 ##############
/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/sklearn/mixture/_base.py:269: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data. % (init + 1), ConvergenceWarning)
logp Gaussin Mixture with k = 70 = -145.83991075448344 logp Gaussian Diagonal = -176.16444328638678 At k = 70 , rel_diff for logps = 0.17213764574844084 ##############
/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/sklearn/mixture/_base.py:269: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data. % (init + 1), ConvergenceWarning)
logp Gaussin Mixture with k = 80 = -137.53119805187438 logp Gaussian Diagonal = -176.16444328638678 At k = 80 , rel_diff for logps = 0.2193021730935065 ##############
/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/sklearn/mixture/_base.py:269: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data. % (init + 1), ConvergenceWarning)
logp Gaussin Mixture with k = 100 = -137.11314433122263 logp Gaussian Diagonal = -176.16444328638678 At k = 100 , rel_diff for logps = 0.2216752610609355 ##############
/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/sklearn/mixture/_base.py:269: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data. % (init + 1), ConvergenceWarning)
logp Gaussin Mixture with k = 200 = -123.6518884587481 logp Gaussian Diagonal = -176.16444328638678 At k = 200 , rel_diff for logps = 0.2980882739331804 ##############
For Dateset "kitti" Calculating relative difference of log likelihood
Latent_dim = 64, Gaussiam Mixture covariance type = diag and reg_covar = 0.1
logp Gaussin Mixture with k = 10 = -75.45441572317257
logp Gaussian Diagonal = -115.44620189247304
At k = 10 , rel_diff for logps = 0.346410583576833
logp Gaussin Mixture with k = 20 = -69.64451526564827
logp Gaussian Diagonal = -115.44620189247304
At k = 20 , rel_diff for logps = 0.39673619292806717
logp Gaussin Mixture with k = 50 = -55.93752814502553
logp Gaussian Diagonal = -115.44620189247304
At k = 50 , rel_diff for logps = 0.5154667089253753
logp Gaussin Mixture with k = 70 = -50.88304211844375
logp Gaussian Diagonal = -115.44620189247304
At k = 70 , rel_diff for logps = 0.559248885763809
logp Gaussin Mixture with k = 80 = -52.1276428852566
logp Gaussian Diagonal = -115.44620189247304
At k = 80 , rel_diff for logps = 0.5484681000262923
logp Gaussin Mixture with k = 100 = -54.52694195135895
logp Gaussian Diagonal = -115.44620189247304
At k = 100 , rel_diff for logps = 0.5276852676180243
logp Gaussin Mixture with k = 200 = -68.2897651271402
logp Gaussian Diagonal = -115.44620189247304
At k = 200 , rel_diff for logps = 0.40847109729304476
For Dateset VOC Calculating relative difference of log likelihood
logp Gaussin Mixture with k = 10 = 3239.302399019124
logp Gaussian Diagonal = 2790.16976645871
At k = 10 , rel_diff for logps = 0.16096964348174925
logp Gaussin Mixture with k = 20 = 3334.117578802105
logp Gaussian Diagonal = 2790.16976645871
At k = 20 , rel_diff for logps = 0.19495151115259016
logp Gaussin Mixture with k = 50 = 3411.8171131494596
logp Gaussian Diagonal = 2790.16976645871
At k = 50 , rel_diff for logps = 0.22279911214138984
logp Gaussin Mixture with k = 70 = 3429.9682203595194
logp Gaussian Diagonal = 2790.16976645871
At k = 70 , rel_diff for logps = 0.22930448949450236
logp Gaussin Mixture with k = 100 = 3428.6442835155385
logp Gaussian Diagonal = 2790.16976645871
At k = 100 , rel_diff for logps = 0.22882998903223795
logp Gaussin Mixture with k = 200 = 3415.867706729245
logp Gaussian Diagonal = 2790.16976645871
At k = 200 , rel_diff for logps = 0.22425084946163418
For Dateset wider_face Calculating relative difference of log likelihood
logp Gaussin Mixture with k = 10 = 1726.1243177708377
logp Gaussian Diagonal = 1561.8860242887904
At k = 10 , rel_diff for logps = 0.10515382744194393
logp Gaussin Mixture with k = 20 = 1751.7959812169918
logp Gaussian Diagonal = 1561.8860242887904
At k = 20 , rel_diff for logps = 0.1215901506095347
logp Gaussin Mixture with k = 50 = 1776.35932923814
logp Gaussian Diagonal = 1561.8860242887904
At k = 50 , rel_diff for logps = 0.13731687307145898
Catched expection Fitting the mixture model failed because some components have ill-defined empirical covariance (for instance caused by singleton or collapsed samples). Try to decrease the number of components, or increase reg_covar.
Catched expection Fitting the mixture model failed because some components have ill-defined empirical covariance (for instance caused by singleton or collapsed samples). Try to decrease the number of components, or increase reg_covar.
Catched expection Fitting the mixture model failed because some components have ill-defined empirical covariance (for instance caused by singleton or collapsed samples). Try to decrease the number of components, or increase reg_covar.
For Dateset "mnist" Calculating relative difference of log likelihood
logp Gaussin Mixture with k = 10 = 7841.956234385311
logp Gaussian Diagonal = 7377.442821602376
At k = 10 , rel_diff for logps = 0.06296401395653828
logp Gaussin Mixture with k = 20 = 7988.197403977028
logp Gaussian Diagonal = 7377.442821602376
At k = 20 , rel_diff for logps = 0.08278675919876485
logp Gaussin Mixture with k = 50 = 8243.583374395243
logp Gaussian Diagonal = 7377.442821602376
At k = 50 , rel_diff for logps = 0.11740389912025657
logp Gaussin Mixture with k = 70 = 8305.125097353703
logp Gaussian Diagonal = 7377.442821602376
At k = 70 , rel_diff for logps = 0.12574577644098017
logp Gaussin Mixture with k = 100 = 8385.913961674707
logp Gaussian Diagonal = 7377.442821602376
At k = 100 , rel_diff for logps = 0.1366965714894272
logp Gaussin Mixture with k = 200 = 8500.999476912375
logp Gaussian Diagonal = 7377.442821602376
At k = 200 , rel_diff for logps = 0.15229622004253812
For Dateset "cifar100" Calculating relative difference of log likelihood
logp Gaussin Mixture with k = 10 = 4348.738804362613
logp Gaussian Diagonal = 4039.8092947584987
At k = 10 , rel_diff for logps = 0.07647131016925454
logp Gaussin Mixture with k = 20 = 4384.332721210587
logp Gaussian Diagonal = 4039.8092947584987
At k = 20 , rel_diff for logps = 0.08528210153362795
logp Gaussin Mixture with k = 50 = 4422.78147918955
logp Gaussian Diagonal = 4039.8092947584987
At k = 50 , rel_diff for logps = 0.09479957010048548
logp Gaussin Mixture with k = 70 = 4428.119142994677
logp Gaussian Diagonal = 4039.8092947584987
At k = 70 , rel_diff for logps = 0.09612083638205283
logp Gaussin Mixture with k = 100 = 4425.925444793071
logp Gaussian Diagonal = 4039.8092947584987
At k = 100 , rel_diff for logps = 0.09557781614482229
logp Gaussin Mixture with k = 200 = 4424.68448280697
logp Gaussian Diagonal = 4039.8092947584987
At k = 200 , rel_diff for logps = 0.09527063283601832